/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// clang-format off
#ifndef MINDSPORE_NNACL_FP32_MATMUL_F32_SSE_H_
#define MINDSPORE_NNACL_FP32_MATMUL_F32_SSE_H_

#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/intrinsics/ms_simd_sse_instructions.h"

#ifdef __cplusplus
extern "C" {
#endif
#pragma GCC push_options
#pragma GCC target("sse4.1")
#define MS_SIMD_INSTRUCTION MS_SIMD_SSE_INSTRUCTION
#define BLOCK_NUM 4
#define MS_SIMD_SSE

/**
 * SSE body of the unpacked GEMM kernel: for each full group of BLOCK_NUM elements starting at
 * `index`, stores act(b[0] * a[i] + bias[0]) into c[i].
 *
 * act_type: 0 = no activation, 1 = relu, 3 = relu6. Any other non-zero value only gets the
 * lower clamp at 0 (i.e. behaves like relu).
 *
 * @param index    first element to process
 * @param a        input vector, read at [index, row)
 * @param b        only b[0] is read; it is broadcast to every lane
 * @param c        output vector, written at the same positions that are read from a
 * @param bias     only bias[0] is read (unconditionally, so it must not be NULL)
 * @param row      number of elements in a / c
 * @param deep     not used by this kernel
 * @param act_type activation selector, see above
 * @return the first index that was NOT processed; fewer than BLOCK_NUM trailing elements are
 *         never touched here.
 */
static inline int64_t GemmIsNotPackSSE(int64_t index, const float *a, const float *b, float *c, const float *bias, int row,
  int deep, int act_type) {
  SIMD_F32 zero_vec = SIMD_MOV_F32(0.0f);
  SIMD_F32 six_vec = SIMD_MOV_F32(6.0f);
  SIMD_F32 scale_vec = SIMD_MOV_F32(b[0]);
  SIMD_F32 offset_vec = SIMD_MOV_F32(bias[0]);

  // The upper clamp is only ever applied together with the lower one (3 != 0).
  const int clamp_low = (act_type != 0);
  const int clamp_high = (act_type == 3);

  // Last start position (exclusive) that still leaves a full block inside [0, row).
  const int vec_end = row - BLOCK_NUM + 1;
  while (index < vec_end) {
    SIMD_F32 src = SIMD_LD_F32(a + index);
    SIMD_F32 result = SIMD_FMADD_F32(scale_vec, src, offset_vec);
    if (clamp_low) {
      result = SIMD_MAX_F32(result, zero_vec);
    }
    if (clamp_high) {
      result = SIMD_MIN_F32(result, six_vec);
    }
    SIMD_ST_F32(c + index, result);
    index += BLOCK_NUM;
  }

  return index;
}

/**
 * SSE body of the row == 1, deep == 1 unpacked GEMM with bias: for each full group of BLOCK_NUM
 * columns starting at `index`, stores act(a[0] * b[j] + bias[j]) into c[j].
 *
 * act_type: 1 = relu, 3 = relu6, every other value = no activation.
 *
 * @param index    first column to process
 * @param a        only a[0] is read; it is broadcast to every lane
 * @param b        weights, read at [index, col)
 * @param c        output, written at the same positions that are read from b
 * @param bias     per-column bias, read at the same positions as b (must not be NULL)
 * @param col      number of columns
 * @param act_type activation selector, see above
 * @return the first column index that was NOT processed; fewer than BLOCK_NUM trailing columns
 *         are never touched here.
 */
static inline int64_t Row1Deep1GemmIsNotPackSSE(int64_t index, const float *a, const float *b, float *c, const float *bias, int col,
  int act_type) {
  SIMD_F32 zero_vec = SIMD_MOV_F32(0.0f);
  SIMD_F32 six_vec = SIMD_MOV_F32(6.0f);
  SIMD_F32 lhs = SIMD_MOV_F32(a[0]);
  // Last start position (exclusive) that still leaves a full block inside [0, col).
  const int vec_end = col - BLOCK_NUM + 1;

  // The activation is selected once, outside the loops, so each loop body stays branch-free.
  switch (act_type) {
    case 1:  // relu
      while (index < vec_end) {
        SIMD_F32 rhs = SIMD_LD_F32(b + index);
        SIMD_F32 offset = SIMD_LD_F32(bias + index);
        SIMD_F32 acc = SIMD_FMADD_F32(lhs, rhs, offset);
        SIMD_ST_F32(c + index, SIMD_MAX_F32(acc, zero_vec));
        index += BLOCK_NUM;
      }
      break;
    case 3:  // relu6
      while (index < vec_end) {
        SIMD_F32 rhs = SIMD_LD_F32(b + index);
        SIMD_F32 offset = SIMD_LD_F32(bias + index);
        SIMD_F32 acc = SIMD_FMADD_F32(lhs, rhs, offset);
        SIMD_ST_F32(c + index, SIMD_CLAMP_F32(acc, zero_vec, six_vec));
        index += BLOCK_NUM;
      }
      break;
    default:  // no activation
      while (index < vec_end) {
        SIMD_F32 rhs = SIMD_LD_F32(b + index);
        SIMD_F32 offset = SIMD_LD_F32(bias + index);
        SIMD_F32 acc = SIMD_FMADD_F32(lhs, rhs, offset);
        SIMD_ST_F32(c + index, acc);
        index += BLOCK_NUM;
      }
      break;
  }

  return index;
}

/**
 * SSE body of the row == 1, deep == 1 unpacked GEMM without bias: for each full group of
 * BLOCK_NUM columns starting at `index`, stores act(a[0] * b[j]) into c[j].
 *
 * act_type: 1 = relu, 3 = relu6, every other value = no activation.
 *
 * @param index    first column to process
 * @param a        only a[0] is read; it is broadcast to every lane
 * @param b        weights, read at [index, col)
 * @param c        output, written at the same positions that are read from b
 * @param bias     never read by this kernel
 * @param col      number of columns
 * @param act_type activation selector, see above
 * @return the first column index that was NOT processed; fewer than BLOCK_NUM trailing columns
 *         are never touched here.
 */
static inline int64_t Row1Deep1NoBiasGemmIsNotPackSSE(int64_t index, const float *a, const float *b, float *c, const float *bias, int col,
  int act_type) {
  SIMD_F32 zero_vec = SIMD_MOV_F32(0.0f);
  SIMD_F32 six_vec = SIMD_MOV_F32(6.0f);
  SIMD_F32 lhs = SIMD_MOV_F32(a[0]);
  // Last start position (exclusive) that still leaves a full block inside [0, col).
  const int vec_end = col - BLOCK_NUM + 1;

  // The activation is selected once, outside the loops, so each loop body stays branch-free.
  switch (act_type) {
    case 1:  // relu
      while (index < vec_end) {
        SIMD_F32 rhs = SIMD_LD_F32(b + index);
        SIMD_F32 prod = SIMD_MUL_F32(lhs, rhs);
        SIMD_ST_F32(c + index, SIMD_MAX_F32(prod, zero_vec));
        index += BLOCK_NUM;
      }
      break;
    case 3:  // relu6
      while (index < vec_end) {
        SIMD_F32 rhs = SIMD_LD_F32(b + index);
        SIMD_F32 prod = SIMD_MUL_F32(lhs, rhs);
        SIMD_ST_F32(c + index, SIMD_CLAMP_F32(prod, zero_vec, six_vec));
        index += BLOCK_NUM;
      }
      break;
    default:  // no activation
      while (index < vec_end) {
        SIMD_F32 rhs = SIMD_LD_F32(b + index);
        SIMD_F32 prod = SIMD_MUL_F32(lhs, rhs);
        SIMD_ST_F32(c + index, prod);
        index += BLOCK_NUM;
      }
      break;
  }

  return index;
}

#if defined(MS_SIMD_AVX512) || defined(MS_SIMD_AVX)
/**
 * Accumulates the dot product of a and b over every full group of BLOCK_NUM elements starting
 * at `index`, and adds the lane-wise total to *dst.
 *
 * @param index first element to process
 * @param a     left operand, read at [index, k)
 * @param b     right operand (weights), read at the same positions as a
 * @param k     number of elements in a / b
 * @param dst   scalar accumulator; the partial sum is ADDED to its current value, and that
 *              addition is performed even when no full block was available
 * @return the first index that was NOT processed; fewer than BLOCK_NUM trailing elements are
 *         not included in the sum.
 */
static inline int64_t GemmIsNotPackOptimizeCoreSSE(int64_t index, const float *a, const float *b, int k, float *dst) {
  SIMD_F32 partial_sum = SIMD_MOV_F32(0.0f);
  // Last start position (exclusive) that still leaves a full block inside [0, k).
  const int vec_end = k - BLOCK_NUM + 1;
  while (index < vec_end) {
    SIMD_F32 rhs = SIMD_LD_F32(b + index);
    SIMD_F32 lhs = SIMD_LD_F32(a + index);
    partial_sum = SIMD_FMADD_F32(rhs, lhs, partial_sum);
    index += BLOCK_NUM;
  }
  // Collapse the per-lane partial sums into the caller's scalar accumulator.
  *dst += SIMD_REDUCE_ADD_F32(partial_sum);
  return index;
}
#endif

/**
 * SSE body of the unpacked matrix-vector product. For each full group of BLOCK_NUM output
 * channels starting at `oc_index` it computes
 *     c[oc .. oc + BLOCK_NUM) = init + sum_{k < depth} a[k] * b[k * col + oc .. oc + BLOCK_NUM)
 * and optionally applies an activation.
 *
 * inc_flag bits:
 *   bit 0 (0x1): when set, `init` is the bias (or 0 if bias == NULL); when clear, `init` is
 *                the value already stored in c, so the result accumulates onto c.
 *   bit 1 (0x2): when set, the activation chosen by act_type is applied before storing.
 *
 * act_type: 0 = no activation, 3 = relu6, any other value = relu (lower clamp at 0 only).
 *
 * @param oc_index first output channel to process
 * @param a        input vector, depth elements
 * @param b        weights; element (k, oc) is read from b[k * col + oc]
 * @param c        output (and, when bit 0 of inc_flag is clear, also the initial accumulator)
 * @param bias     per-channel bias or NULL; only consulted when bit 0 of inc_flag is set
 * @param act_type activation selector, see above
 * @param depth    number of elements reduced over
 * @param oc       number of output channels to cover
 * @param col      distance, in floats, between consecutive k-rows of b
 * @param inc_flag bit mask, see above
 * @return the first output channel that was NOT processed; fewer than BLOCK_NUM trailing
 *         channels are never touched here.
 */
static inline int64_t MatVecMulNoPackCoreSSE(int64_t oc_index, const float *a, const float *b, float *c, const float *bias,
  int act_type, int64_t depth, int64_t oc, int64_t col, int64_t inc_flag) {
  const int init_from_output = ((inc_flag & 0x1) == 0);
  const int clamp_low = ((inc_flag & 0x2) != 0) && (act_type != 0);
  const int clamp_high = clamp_low && (act_type == 0x3);
  // Largest start position that still leaves a full block inside [0, oc).
  const int64_t last_start = oc - BLOCK_NUM;

  while (oc_index <= last_start) {
    SIMD_F32 acc;
    if (init_from_output) {
      acc = SIMD_LD_F32(c + oc_index);
    } else if (bias == NULL) {
      acc = SIMD_MOV_F32(0.0f);
    } else {
      acc = SIMD_LD_F32(bias + oc_index);
    }

    for (int64_t d = 0; d < depth; ++d) {
      SIMD_F32 lhs = SIMD_MOV_F32(a[d]);
      SIMD_F32 rhs = SIMD_LD_F32(b + oc_index + d * col);
      acc = SIMD_FMADD_F32(lhs, rhs, acc);
    }

    if (clamp_low) {
      acc = SIMD_MAX_F32(acc, SIMD_MOV_F32(0.0f));
    }
    if (clamp_high) {
      acc = SIMD_MIN_F32(acc, SIMD_MOV_F32(6.0f));
    }

    SIMD_ST_F32(c + oc_index, acc);
    oc_index += BLOCK_NUM;
  }
  return oc_index;
}

#undef MS_SIMD_INSTRUCTION
#undef BLOCK_NUM
#pragma GCC pop_options
#undef MS_SIMD_SSE
#ifdef __cplusplus
}
#endif
#endif
